#!/usr/bin/env python

"Turns the subset of Python we support into virtual bytecode for Java code to run"

# We use Python's ast module for the heavy lifting.  It parses the
# target file and we then operate on the resulting abstract syntax
# tree.

import sys, os
import ast
import linecache
import array
import re
import functools

py3 = sys.version_info[0] >= 3

if py3:
    # The compiler depends on Python 2 ast shapes (eg Print statements and
    # Call.starargs), so refuse to run under Python 3.
    sys.exit("Currently jmp-compile only runs under Python 2.")

if py3:
    # Not reached while the exit above is in place; keeps the 2/3 shims
    # below (isinstance(..., basestring)) meaningful under Python 3.
    basestring = str

class Label:
    "Marker for a position in the generated byte code; replaced by the code offset it marks"
    def __repr__(self):
        return "Label_%s" % (id(self),)

# Python 2 only has the print statement (print() needs the print_function
# future import), so this emulates the subset of print() used here.
def write(*args, **kwargs):
    """Write args to a stream the way print() would.

    Keyword arguments: end (default newline), file (default sys.stdout)
    and sep (default a single space).  Arguments that are not already
    strings are converted with str().  Any other keyword raises Exception.
    """
    terminator=kwargs.pop("end", "\n")
    stream=kwargs.pop("file", sys.stdout)
    separator=kwargs.pop("sep", " ")
    if kwargs:
        raise Exception("Unexpected args "+repr(kwargs))
    pieces=(x if isinstance(x, basestring) else str(x) for x in args)
    stream.write(separator.join(pieces)+terminator)

# Decorator for ast visitor methods that returns nice errors: exceptions
# raised while processing a node become a Compiler.Error located at that
# node, which Compiler.compile reports to the user.
def errorfy(func):
    """Wrap a visit method taking (self, node) so failures become Compiler.Error.

    A Compiler.Error raised further down is re-raised untouched so it
    keeps the (more precise) node it was created with.  When
    self.options.traceback is set every exception propagates unchanged so
    the original traceback is preserved.
    """
    @functools.wraps(func)
    def wrapped(self, node):
        try:
            return func(self, node)
        except Exception:
            t,e,_=sys.exc_info()
            # Check the exception instance - t is the exception class
            if isinstance(e, Compiler.Error) or self.options.traceback:
                raise
            # Bare asserts have an empty message so use the exception name
            raise Compiler.Error(node, str(e) or t.__name__)
    return wrapped

class Compiler(ast.NodeVisitor):
    """Compiles the supported subset of Python into jmp virtual bytecode.

    Every visit_* method is a generator yielding (node, item) pairs.  node
    is the ast node whose line is recorded in the line number table (it
    may be None).  item is one of:

    - an opcode name string: a one byte instruction
    - an (opcode name, operand) tuple: a three byte instruction whose
      operand is an int, a string (replaced by its string table index) or
      a Label (replaced by the code offset the label marks)
    - a Label: marks the current code position and emits nothing

    compile() drives the whole process and writes the .jmp file.
    """

    # ast classes to string opcodes
    _opmap={
        ast.Add: "ADD",
        ast.Sub: "SUB",
        ast.Mult: "MULT",
        ast.Div: "DIV",
        ast.FloorDiv: "DIV",
        ast.Mod: "MOD",

        ast.Lt: "LT",
        ast.LtE: "LTE",
        ast.Eq: "EQ",
        ast.NotEq: "NOT_EQ",
        ast.Gt: "GT",
        ast.GtE: "GTE",

        ast.In: "IN",
        ast.Is: "IS",

        ast.UAdd: "UNARY_ADD",
        ast.USub: "UNARY_NEG",
        ast.Not: "NOT",

        ast.And: "AND",
        ast.Or: "OR",
        }

    def get_op(self, node, *op):
        """Return the opcode name for node.op, or for op[0] when supplied.

        Raises Error (located at node) when the operator has no opcode.
        """
        op=op[0] if op else node.op
        try:
            return self._opmap[op.__class__]
        except KeyError:
            raise self.Error(node, op.__class__.__name__+" not supported")

    class Error(Exception):
        "A problem in the program being compiled, reported against node"
        def __init__(self, node, message):
            Exception.__init__(self, message)
            self.message=message
            self.node=node

    # main method to turn .py into .jmp
    def compile(self, options, inputfile, outputfile):
        """Compile inputfile into jmp bytecode.

        options supplies the command line settings read here and by the
        visitors (print_function, constants, optimization, line_table,
        jmpoutput, annotate, traceback, asserts).  When options.jmpoutput
        is set the bytecode is written to outputfile, plus an annotated
        listing (".i" file) next to it when options.annotate is set.

        Failures do not return: syntax errors and Compiler.Error exit with
        status 2, exceeded size limits and unknown opcodes with status 3,
        and opcode_for failures with status 5.
        """
        self.inputfile=inputfile
        self.options=options
        # Enclosing loops as (break label, continue label, kind); None
        # marks a function (or the module) boundary.
        self.bcstack=[None]
        self.counter=0
        flags=ast.PyCF_ONLY_AST
        if options.print_function:
            import __future__
            flags|=__future__.print_function.compiler_flag

        # Read separately so the file is closed as soon as we have the text
        with open(inputfile, "rb") as f:
            source=f.read()
        try:
            root=compile(source, inputfile, 'exec', flags)
        except SyntaxError:
            e=sys.exc_info()[1]
            write("SyntaxError:", e.msg,"at line",e.lineno,"of", inputfile, file=sys.stderr)
            write(self._source_line(inputfile, e.lineno), end="", file=sys.stderr)
            if e.offset is not None:
                write ((e.offset-1 if e.offset else 0)*" "+"^", file=sys.stderr)
            sys.exit(2)

        ast.fix_missing_locations(root)

        res=[]
        try:
            # Replace x+=1 with x=x+1
            aug=AugmentedAssigns()
            aug.options=options
            root=aug.visit(root)
            ast.fix_missing_locations(root)

            # Constant expansions
            if options.constants:
                const=ConstantExpander()
                const.options=options
                root=const.visit(root)
                ast.fix_missing_locations(root)

            # Removal of dead code, static expression evaluation etc
            if options.optimization:
                opty=Optimizer()
                opty.options=options
                root=opty.visit(root)
                ast.fix_missing_locations(root)

            # Regular compilation.  Items are appended one at a time so res
            # holds everything emitted so far if an error is raised.
            for node,v  in self.visit(root):
                res.append( (node, v) )

            # string and location tables
            strings=[]
            stringindex={} # string -> position in strings
            locations={}
            linenotab=[]
            lastlineno=-1
            pos=0
            out=[]
            for n,v in res:
                if isinstance(v, Label):
                    locations[v]=pos
                    continue
                if n is not None and hasattr(n, "lineno") and lastlineno!=n.lineno:
                    linenotab.append( (pos, n.lineno) )
                    lastlineno=n.lineno
                if not isinstance(v, tuple):
                    if v not in opcode_map:
                        write("Internal error: unknown opcode \"%s\"" % (v,), file=sys.stderr)
                        sys.exit(3)
                    assert opcode_map[v]<128
                    out.append( (n,v) )
                    pos+=1
                    continue
                pos+=3
                op, val=v
                if op not in opcode_map:
                    write ("Internal error: unknown opcode \"%s\"" % (op,), file=sys.stderr)
                    sys.exit(3)

                opcode=opcode_map[op]

                if is_string_op(opcode):
                    assert isinstance(val, basestring)
                    if val not in stringindex:
                        if len(val)>65535:
                            raise self.Error(n, "String is too long: "+str(len(val)))
                        stringindex[val]=len(strings)
                        strings.append(val)
                    v=(op, stringindex[val])
                    out.append( (n, v) )
                elif is_code_op(opcode):
                    assert isinstance(val, Label)
                    out.append( (n, v) )
                elif op=="PUSH_INT":
                    # turn into two operations depending on value size
                    if val>=0 and val<=0xffff:
                        out.append( (n, v) )
                    elif val<0 and val>=-0xffff:
                        out.append( (n, ("PUSH_INT", -val)) )
                        out.append( (n, "UNARY_NEG") )
                        pos+=1
                    else:
                        out.append( (n, ("PUSH_INT", val&0xffff)) )
                        out.append( (n, ("PUSH_INT_HI", (val>>16)&0xffff)) )
                        pos+=3
                elif is_int_op(opcode):
                    assert val<0xffff
                    out.append( (n, (op, val)) )
                else:
                    assert False, "Unhandled opcode "+op

            # Generate output
            res=out
            codelen=pos
            if codelen==1:
                write ("There is no executable code in the source file", file=sys.stderr)
                sys.exit(3)
            if codelen>65535:
                write ("Your code is too long - %d instructions" % (codelen,), file=sys.stderr)
                sys.exit(3)
            # The table is in code order, so the largest line number is not
            # necessarily the last entry (eg function bodies, loop ends)
            maxline=max([lineno for _, lineno in linenotab] or [0])
            if maxline>65535:
                write("Too many lines of code - %d" % (maxline,), file=sys.stderr)
                sys.exit(3)

            # fixup labels
            out=[]
            for n,v in res:
                if isinstance(v,tuple) and isinstance(v[1], Label):
                    v=v[0],locations[v[1]]
                out.append( (n,v) )
            res=out

            ### final output - this could really use a bytes type
            out=array.array("B")
            assert out.itemsize==1, "Bytes must be one byte in length!"

            # version
            out.append(0)
            out.append(0)

            # string table (all 16 bit values are little endian)
            if len(strings)>65535:
                write ("There are too many strings (%d) - maximum is %d" % (len(strings), 65535), file=sys.stderr)
                sys.exit(3)
            out.append(len(strings)&0xff)
            out.append(len(strings)>>8)
            for s in strings:
                s=s.encode("utf8")
                if len(s)>65535:
                    write ("One of your strings is too long - %d bytes" % (len(s),), file=sys.stderr)
                    sys.exit(3)
                out.append(len(s)&0xff)
                out.append(len(s)>>8)
                if py3: # s is now bytes
                    out.extend(s)
                else:
                    out.extend(ord(x) for x in s)

            # line number table of (code offset, line number) pairs
            if options.line_table:
                if len(linenotab)>65535:
                    write("There are too many lines with code (%d) - maximum is %d" % (len(linenotab), 65535), file=sys.stderr)
                    sys.exit(3)
                out.append(len(linenotab)&0xff)
                out.append(len(linenotab)>>8)
                for x,y in linenotab:
                    out.append( x & 0xff)
                    out.append( x >> 8 )
                    out.append( y & 0xff)
                    out.append( y >> 8 )
            else:
                out.append(0)
                out.append(0)

            # The code
            out.append(codelen & 0xff)
            out.append(codelen >> 8 )
            startsize=len(out) # sanity check code size
            for n,v in res:
                if isinstance(v, tuple):
                    out.append(self.opcode_for(v[0]))
                    out.append(v[1] & 0xff)
                    out.append(v[1] >> 8)
                else:
                    out.append(self.opcode_for(v))
            assert len(out)-startsize==codelen, "Internal error: code actual size doesn't match calculated size"

            # Write it
            if options.jmpoutput:
                with open(outputfile, "wb") as f:
                    out.tofile(f)

                if options.annotate:
                    options.dump_source=inputfile
                    with open(os.path.splitext(outputfile)[0]+".i", "wt") as outf:
                        dump_internal(options, strings, linenotab, out[-codelen:], outf)

        except self.Error:
            ev=sys.exc_info()[1]
            if options.traceback:
                x=sys.exc_info()
                import traceback
                traceback.print_exception(*x)

            # Slice objects (and some others) don't have line/column set so
            # we use the most recent emitted item that does have them.
            ast.fix_missing_locations(root)
            lineno, column=self._error_location(ev.node, res)
            write (ev.message,"at line", lineno, "of", inputfile, file=sys.stderr)
            write (self._source_line(inputfile, lineno)+column*" "+"^", file=sys.stderr)
            sys.exit(2)

    @staticmethod
    def _source_line(filename, lineno):
        """Return line lineno of filename ending with a newline, or "" if
        the line doesn't exist (so a caret printed after it starts a new line)."""
        line=linecache.getline(filename, lineno)
        if line and not line.endswith("\n"):
            line+="\n"
        return line

    @staticmethod
    def _error_location(node, emitted):
        """Return the (line, column) at which to report an error for node.

        emitted is the list of (node, item) pairs generated so far.  When
        node lacks a line or column, the latest emitted node that has both
        is used instead; pairs with a None node (eg the clean up after a
        for loop) are skipped.  Falls back to (1, 1).
        """
        fallback_line, fallback_col=1, 1
        for n, _ in reversed(emitted):
            if hasattr(n, "lineno") and hasattr(n, "col_offset"):
                fallback_line, fallback_col=n.lineno, n.col_offset
                break
        return (getattr(node, "lineno", fallback_line),
                getattr(node, "col_offset", fallback_col))

    def opcode_for(self, opcode):
        "Return the numeric value of opcode name, exiting with status 5 if unknown"
        try:
            return opcode_map[opcode]
        except KeyError:
            write ("Internal error.  No opcode for \"%s\"" % (opcode,), file=sys.stderr)
            sys.exit(5)

    def generic_visit(self, node):
        "Any node type without a visit_ method is unsupported syntax"
        raise self.Error(node, "Unsupported syntax "+node.__class__.__name__)

    def visit_list(self, node):
        # This happens when the optimizer replaces an always True if
        # statement with its body
        for n in node:
            for x in self.visit(n):
                yield x

    def visit_Module(self, node):
        "The module body followed by EXIT_LOOP"
        for i in node.body:
            assert i is not None
            for x in self.visit(i):
                yield x
        yield None, "EXIT_LOOP"

    def visit_Pass(self, node):
        # needed to force the method to be a generator
        if False:
            yield None

    def visit_Expr(self, node):
        "Expression statement: evaluate then discard the result"
        # docstrings etc
        if isinstance(node.value, (ast.Str, ast.Num)):
            return
        for x in self.visit(node.value):
            yield x
        yield node, "POP_TOP"

    def visit_BoolOp(self, node):
        """and/or: every operand but the last is followed by AND/OR with
        the end label as operand (presumably the runtime's short circuit jump)."""
        end=Label()
        for i,v in enumerate(node.values):
            for x in self.visit(v):
                yield x
            if i!=len(node.values)-1:
                yield node, (self.get_op(node), end)
        yield None, end

    def visit_UnaryOp(self, node):
        for x in self.visit(node.operand):
            yield x
        yield node, self.get_op(node)

    def visit_BinOp(self, node):
        "Left operand, right operand, then the operator's opcode"
        assert len(node.op._fields)==0
        for x in self.visit(node.left):
            yield x
        for x in self.visit(node.right):
            yield x
        yield node, self.get_op(node)

    def visit_Compare(self, node):
        "A single comparison; 'not in' and 'is not' become IN/IS followed by NOT"
        if len(node.ops)!=1:
            raise self.Error(node, "Only one comparison is supported (you can't do x<y<z)")
        for x in self.visit(node.left):
            yield x
        for x in self.visit(node.comparators[0]):
            yield x
        if isinstance(node.ops[0], ast.NotIn):
            yield node, "IN"
            yield node, "NOT"
        elif isinstance(node.ops[0], ast.IsNot):
            yield node, "IS"
            yield node, "NOT"
        else:
            yield node, self.get_op(node, node.ops[0])

    def visit_Num(self, node):
        "Integer literal; must fit in a signed 32 bit value"
        if not isinstance(node.n, int):
            raise self.Error(node, "Only integers are supported")
        if node.n>0x7fffffff or node.n<-0x80000000:
            raise self.Error(node, "Value %d is too large" % node.n)
        yield node, ("PUSH_INT", node.n)

    def visit_Str(self, node):
        yield node, ("PUSH_STR", node.s)

    def visit_Name(self, node):
        "Name lookup; True/False/None have their own push opcodes"
        assert len(node.ctx._fields)==0
        # we don't allow true/false/none to be redefined
        if node.id=="True":
            yield node, "PUSH_TRUE"
        elif node.id=="False":
            yield node, "PUSH_FALSE"
        elif node.id=="None":
            yield node, "PUSH_NONE"
        else:
            yield node, ("LOAD_NAME", node.id)

    def visit_Global(self, node):
        for name in node.names:
            yield node, ("GLOBAL", name)

    def visit_Attribute(self, node):
        "x.attr: pushes the attribute name, then x, then ATTR"
        assert len(node.ctx._fields)==0
        yield node, ("PUSH_STR", node.attr)
        for x in self.visit(node.value):
            yield x
        yield node, "ATTR"

    def visit_Subscript(self, node):
        "x[i] emits SUBSCRIPT, x[a:b] emits SUBSCRIPT_SLICE"
        assert len(node.ctx._fields)==0

        for x in self.visit(node.value):
            yield x

        for x in self.visit(node.slice):
            yield x

        if isinstance(node.slice, ast.Index):
            yield node, "SUBSCRIPT"
        else:
            yield node, "SUBSCRIPT_SLICE"

    def visit_Index(self, node):
        for x in self.visit(node.value):
            yield x

    def visit_Slice(self, node):
        "Pushes lower then upper bound, using None for omitted bounds"
        if node.step is not None:
            raise self.Error(node, "Steps not supported in slices")
        if node.lower is None:
            yield None, "PUSH_NONE"
        else:
            for x in self.visit(node.lower):
                yield x
        if node.upper is None:
            yield None, "PUSH_NONE"
        else:
            for x in self.visit(node.upper):
                yield x

    def visit_List(self, node):
        "List (and tuple) displays: the elements, their count, then LIST"
        assert len(node.ctx._fields)==0
        for x in node.elts:
            for y in self.visit(x):
                yield y
        yield node, ("PUSH_INT", len(node.elts))
        yield node, "LIST"

    visit_Tuple=visit_List

    def visit_Dict(self, node):
        "Dict displays: value/key pairs, the pair count, then DICT"
        for k,v in zip(node.keys, node.values):
            # Note we visit the value before the key.
            # This is to match what Python does.
            for x in self.visit(v):
                yield x
            for x in self.visit(k):
                yield x
        yield node, ("PUSH_INT", len(node.keys))
        yield node, "DICT"

    def visit_Assign(self, node):
        "Assignment to a single name, attribute or x[i] target"
        if len(node.targets)!=1:
            raise self.Error(node, "Only assigning to one item at a time is supported")
        target=node.targets[0]
        if not isinstance(target, (ast.Name, ast.Subscript, ast.Attribute)):
            raise self.Error(target, "That kind of assignment is not supported")
        if isinstance(target, ast.Name):
            for x in self.visit(node.value):
                yield x
            if target.id in ("True", "False", "None"):
                raise self.Error(target, "You can't assign to constants")
            yield node, ("STORE_NAME", target.id)
        elif isinstance(target, ast.Attribute):
            # python evaluation order is rhs then lhs
            for x in self.visit(node.value):
                yield x
            for x in self.visit(target.value):
                yield x
            yield node, ("STORE_ATTR_NAME", target.attr)
        else:
            if not isinstance(target.slice, ast.Index):
                raise self.Error(node, "Only simple array/list subscript assignment supported")
            # NOTE(review): container and index are evaluated before the
            # value here (Python does value first); ASSIGN_INDEX presumably
            # depends on this stack order so it is left as is.
            for x in self.visit(target.value):
                yield x
            for x in self.visit(target.slice.value):
                yield x
            for x in self.visit(node.value):
                yield x
            yield node, "ASSIGN_INDEX"

    def visit_IfExp(self, node):
        "body if test else orelse"
        falsearm=Label()
        end=Label()
        for x in self.visit(node.test):
            yield x
        yield node, ("IF_FALSE", falsearm)
        for x in self.visit(node.body):
            yield x
        yield None, ("GOTO", end)
        yield None, falsearm
        for x in self.visit(node.orelse):
            yield x
        yield None, end

    def visit_If(self, node):
        "if/elif/else statement (elif arrives as a nested If in orelse)"
        for x in self.visit(node.test):
            yield x
        end=Label()
        elsepart=Label() if node.orelse else None
        yield node, ("IF_FALSE", elsepart if elsepart else end)
        for x in node.body:
            assert x is not None
            for y in self.visit(x):
                yield y
        if elsepart:
            yield node, ("GOTO", end)
            yield node, elsepart
            for x in node.orelse:
                assert x is not None
                for y in self.visit(x):
                    yield y
        yield node, end

    def visit_While(self, node):
        "while loop (no else); continue jumps back to the test"
        if node.orelse:
            raise self.Error(node, "else not supported for while")
        test=Label()
        end=Label()
        self.bcstack.append( (end, test, "while") )
        yield node, test
        for x in self.visit(node.test):
            yield x
        yield node, ("IF_FALSE", end)
        for x in node.body:
            assert x is not None
            for y in self.visit(x):
                yield y
        yield None, ("GOTO", test)
        yield node, end
        self.bcstack.pop()

    def visit_For(self, node):
        """for loop over a single name (no else).  The iterator stays on
        the stack for the loop's duration and is popped after the end label."""
        if node.orelse:
            raise self.Error(node, "else not supported for 'for'")
        if not isinstance(node.target, ast.Name):
            raise self.Error(node.target, "Can only use single name for 'for' variable")

        loop=Label()
        end=Label()
        self.bcstack.append( (end, loop, "for") )
        for x in self.visit(node.iter):
            yield x
        yield node, "ITER"
        yield node, loop
        yield node, ("NEXT", end)
        yield node.target, ("STORE_NAME", node.target.id)
        for x in node.body:
            assert x is not None
            for y in self.visit(x):
                yield y
        yield None, ("GOTO", loop)
        yield None, end
        # get rid of iterator
        yield None, "POP_TOP"
        self.bcstack.pop()

    def visit_Break(self, node):
        "Jump to the innermost loop's end label"
        bc=self.bcstack[-1]
        if bc is None:
            raise self.Error(node, "'break' not relevant here")
        yield node, ("GOTO", bc[0])

    def visit_Continue(self, node):
        "Jump to the innermost loop's test/next label"
        bc=self.bcstack[-1]
        if bc is None:
            raise self.Error(node, "'continue' not relevant here")
        yield node, ("GOTO", bc[1])

    def visit_ListComp(self, node):
        """List comprehension.  The result is built in a hidden variable
        named _l[N] (N is the comprehension nesting depth) by calling its
        append method, then left on the stack and the variable deleted."""
        if not hasattr(self, "list_comp_nesting"):
            self.list_comp_nesting=-1
        self.list_comp_nesting+=1
        name="_l[%d]" % (self.list_comp_nesting)
        yield node, ("PUSH_INT", 0)
        yield node, "LIST"
        yield node, ("STORE_NAME", name)
        stack=[]
        for gen in node.generators:
            if not isinstance(gen.target, ast.Name):
                raise self.Error(gen.target, "Can only use single name for 'for' variable")
            loop=Label()
            end=Label()
            stack.append( (loop, end) )
            for x in self.visit(gen.iter):
                yield x
            yield gen, "ITER"
            yield gen, loop
            yield gen, ("NEXT", end)
            yield gen, ("STORE_NAME", gen.target.id)

            for cond in gen.ifs:
                for x in self.visit(cond):
                    yield x
                yield cond, ("IF_FALSE", loop)

            if gen is node.generators[-1]:
                # we are adding this one - construct call to append
                for x in self.visit(node.elt):
                    yield x
                yield gen, ("PUSH_INT", 1) # nargs
                yield gen, ("PUSH_STR", "append")
                yield gen, ("LOAD_NAME", name)
                yield gen, "ATTR"
                yield gen, "CALL"
                yield gen, "POP_TOP"

        for loop, end in stack[::-1]:
            yield None, ("GOTO", loop)
            yield None, end
            yield None, "POP_TOP"
        # leave list on stack and get rid of our tempname
        yield node, ("LOAD_NAME", name)
        yield node, ("DEL_NAME", name)
        self.list_comp_nesting-=1

    def visit_Print(self, node):
        "print statement: values, newline flag, value count, then PRINT"
        if node.dest is not None:
            raise self.Error(node, "print targets are not supported")
        for x in node.values:
            for y in self.visit(x):
                yield y
        yield node, "PUSH_TRUE" if node.nl else "PUSH_FALSE"
        yield node, ("PUSH_INT", len(node.values))
        yield node, "PRINT"

    def visit_Assert(self, node):
        """assert statement, only compiled when options.asserts is set.
        Without a message the source text of the test is used."""
        if self.options.asserts:
            end=Label()
            for x in self.visit(node.test):
                yield x
            yield node, "NOT"
            yield node, ("IF_FALSE", end)
            if node.msg is not None:
                for x in self.visit(node.msg):
                    yield x
            else:
                yield node, ("PUSH_STR", Printer().as_string(self.inputfile, node.test))
            yield node, "ASSERT_FAILED"
            yield None, end

    def visit_FunctionDef(self, node):
        """def (stored under its name) and lambda (left on the stack).

        MAKE_METHOD references the body, which is jumped over.  The body
        starts with FUNCTION_PROLOG and stores the arguments last to first.
        A def not ending in RETURN gets an implicit 'return None'.
        """
        # this serves both functions and lamdba
        self.bcstack.append(None)
        if getattr(node, "decorator_list", None):
            raise self.Error(node, "Decorators are not supported")
        args=node.args
        if getattr(args, "defaults", None):
            raise self.Error(node, "Default arguments are not supported")
        if getattr(args, "kwarg", None):
            raise self.Error(node, "Keyword arguments are not supported")
        if getattr(args, "vararg", None):
            raise self.Error(node, "Varargs (*args) are not supported")
        body=Label()
        end=Label()
        yield node, ("MAKE_METHOD", body)
        if not isinstance(node, ast.Lambda):
            yield node, ("STORE_NAME", node.name)
        yield node, ("GOTO", end)
        yield node, body
        for i, n in enumerate(args.args):
            if not isinstance(n, ast.Name):
                raise self.Error(n, "Only single named arguments are supported (#%d)" % (i,))
        yield node, ("PUSH_INT", len(args.args))
        yield node, "FUNCTION_PROLOG"
        for n in reversed(args.args):
            yield n, ("STORE_NAME", n.id)
        last=None
        if isinstance(node, ast.Lambda):
            for x in self.visit(node.body):
                yield x
            yield None, "RETURN"
        else:
            for x in node.body:
                assert x is not None
                for y in self.visit(x):
                    last=y[1]
                    yield y
            if last!="RETURN":
                yield None, "PUSH_NONE"
                yield None, "RETURN"
        yield node, end
        self.bcstack.pop()

    visit_Lambda=visit_FunctionDef

    def visit_Call(self, node):
        "Positional-only call: arguments, their count, the callee, then CALL"
        if node.starargs:
            raise self.Error(node, "*args not supported")
        if node.kwargs or node.keywords:
            raise self.Error(node, "Keyword arguments are not supported")
        for x in node.args:
            for y in self.visit(x):
                yield y
        yield node, ("PUSH_INT", len(node.args))
        for x in self.visit(node.func):
            yield x
        yield node, "CALL"

    def visit_Return(self, node):
        """return statement.  First pops the iterators that enclosing for
        loops (within the current function) keep on the stack."""
        # pop iterators off the stack first
        icount=0
        for i in self.bcstack[::-1]:
            if i is None:
                break
            if i[2]=="for":
                icount+=1
        if icount==1:
            yield node, "POP_TOP"
        elif icount:
            yield node, ("POP_N", icount)
        # now return
        if node.value is None:
            yield node, "PUSH_NONE"
        else:
            for x in self.visit(node.value):
                yield x
        yield node, "RETURN"

    def visit_Delete(self, node):
        "del of names (DEL_NAME), x[i] (DEL_INDEX) and x[a:b] (DEL_SLICE)"
        for target in node.targets:
            if isinstance(target, ast.Name):
                yield target, ("DEL_NAME", target.id)
            elif isinstance(target, ast.Subscript):
                for x in self.visit(target.value):
                    yield x
                for x in self.visit(target.slice):
                    yield x
                if isinstance(target.slice, ast.Index):
                    yield target, "DEL_INDEX"
                else:
                    yield target, "DEL_SLICE"
            else:
                raise self.Error(target, "Unsupported item to delete")


class Optimizer(ast.NodeTransformer):
    "Perform some optimizations (eg folding constant expressions, removing dead code)"

    # Use this to iterate over lists of statements (eg an if node's body).
    # Each item is visited in place; items the visit removed (None) are
    # dropped and lists returned in place of a statement are spliced in,
    # matching what NodeTransformer.generic_visit does for lists.
    def process_list(self, l):
        if not l:
            return
        res=[]
        for item in l:
            r=self.visit(item)
            if r is None:
                continue
            if isinstance(r, list):
                # eg an always true if statement replaced by its body
                res.extend(r)
            else:
                res.append(r)
        l[:]=res

    def is_constant(self, node):
        "True for string/number literals and loads of None/True/False"
        if isinstance(node, (ast.Str, ast.Num)):
            return True
        if isinstance(node, ast.Name) and isinstance(node.ctx, ast.Load):
            if node.id in ("None", "True", "False"):
                return True
        return False

    @errorfy
    def get_val(self, node):
        "Return the Python value of a node for which is_constant is true"
        assert self.is_constant(node)
        if isinstance(node, ast.Str):
            return node.s
        if isinstance(node, ast.Num):
            return node.n
        if isinstance(node, ast.Name):
            return {
                "None": None,
                "True": True,
                "False": False}[node.id]
        raise Exception("Can't get_val "+repr(node))

    def nodify(self, sourcenode, value):
        """Return an ast node for the constant value, located at sourcenode.

        Raises Exception for values with no supported literal form (eg
        floats)."""
        # bools are ints so they must be checked before the int case
        if value is True:
            v=ast.Name("True", ast.Load())
        elif value is False:
            v=ast.Name("False", ast.Load())
        elif value is None:
            v=ast.Name("None", ast.Load())
        elif isinstance(value, int):
            v=ast.Num(value)
        elif isinstance(value, basestring):
            v=ast.Str(value)
        else:
            raise Exception("Can't nodify "+repr(value))
        return ast.copy_location(v, sourcenode)

    @errorfy
    def visit_BinOp(self, node):
        "Fold +, -, *, % and division of two constants"
        node.left=self.visit(node.left)
        node.right=self.visit(node.right)
        if self.is_constant(node.left) and self.is_constant(node.right):
            l=self.get_val(node.left)
            r=self.get_val(node.right)
            if isinstance(node.op, ast.Add):
                return self.nodify(node, l+r)
            elif isinstance(node.op, ast.Sub):
                return self.nodify(node, l-r)
            elif isinstance(node.op, ast.Mult):
                return self.nodify(node, l*r)
            elif isinstance(node.op, ast.Mod):
                return self.nodify(node, l%r)
            elif isinstance(node.op, (ast.Div, ast.FloorDiv)):
                # -2**31 / -1 overflows 32 bits; give the wrapped result
                # (presumably what the Java runtime produces)
                if l==-2147483648 and r==-1:
                    return self.nodify(node, -2147483648)
                # Both operators compile to DIV.  // is what Python 2's
                # int / int already does and stays integer under Python 3.
                return self.nodify(node, l//r)

        return node

    @errorfy
    def visit_Compare(self, node):
        "Fold a single ordering/equality comparison of two constants"
        if len(node.ops)!=1 or len(node.comparators)!=1:
            return node
        node.left=self.visit(node.left)
        node.comparators[0]=self.visit(node.comparators[0])
        if self.is_constant(node.left) and self.is_constant(node.comparators[0]):
            l=self.get_val(node.left)
            r=self.get_val(node.comparators[0])
            op=node.ops[0]
            if isinstance(op, ast.Lt):
                return self.nodify(node, l<r)
            elif isinstance(op, ast.LtE):
                return self.nodify(node, l<=r)
            elif isinstance(op, ast.Eq):
                return self.nodify(node, l==r)
            elif isinstance(op, ast.NotEq):
                return self.nodify(node, l!=r)
            elif isinstance(op, ast.Gt):
                return self.nodify(node, l>r)
            elif isinstance(op, ast.GtE):
                return self.nodify(node, l>=r)
        return node

    @errorfy
    def visit_BoolOp(self, node):
        """Drop constant operands of and/or that can't affect the result.
        The last operand is never dropped since it supplies the value."""
        for i in range(len(node.values)):
            node.values[i]=self.visit(node.values[i])
        i=0
        while i+1<len(node.values):
            l=node.values[i]
            r=node.values[i+1]
            if isinstance(node.op, ast.Or):
                if self.is_constant(l):
                    if self.get_val(l):
                        # l is true so we'll never evaluate the rest
                        del node.values[i+1:]
                        i+=1
                        continue
                    else:
                        # l is false so ignore it
                        del node.values[i]
                        continue
            elif isinstance(node.op, ast.And):
                if self.is_constant(l):
                    if self.get_val(l):
                        # l is true so ignore it
                        del node.values[i]
                        continue
                    else:
                        # l is false so we'll never evaluate the rest
                        del node.values[i+1:]
            i+=1

        if len(node.values)==1:
            return node.values[0]
        elif len(node.values)==0:
            return self.nodify(node, False)
        return node

    @errorfy
    def visit_Assert(self, node):
        "Remove asserts that can never fail"
        node.test=self.visit(node.test)
        if node.msg is not None:
            node.msg=self.visit(node.msg)
        if self.is_constant(node.test):
            if self.get_val(node.test):
                return
        return node

    @errorfy
    def visit_UnaryOp(self, node):
        "Fold +, - and not applied to a constant"
        node.operand=self.visit(node.operand)
        if self.is_constant(node.operand):
            v=self.get_val(node.operand)
            if isinstance(node.op, ast.UAdd):
                return self.nodify(node, +v)
            elif isinstance(node.op, ast.USub):
                return self.nodify(node, -v)
            elif isinstance(node.op, ast.Not):
                return self.nodify(node, not v)
        return node

    @errorfy
    def visit_IfExp(self, node):
        "Replace 'a if test else b' by the chosen arm when test is constant"
        node.test=self.visit(node.test)
        node.body=self.visit(node.body)
        node.orelse=self.visit(node.orelse)
        if self.is_constant(node.test):
            if self.get_val(node.test):
                return node.body
            else:
                return node.orelse
        return node

    @errorfy
    def visit_If(self, node):
        """Replace an if statement with a constant test by the statement
        list of the chosen branch, or remove it if that branch is empty."""
        node.test=self.visit(node.test)
        self.process_list(node.orelse)
        self.process_list(node.body)

        if self.is_constant(node.test):
            if self.get_val(node.test):
                return node.body if node.body else None
            return node.orelse if node.orelse else None
        return node

    @errorfy
    def visit_While(self, node):
        """Remove loops whose test is constantly false.  while/else is left
        untouched (the compiler rejects it)."""
        if node.orelse:
            return node
        node.test=self.visit(node.test)
        self.process_list(node.body)
        if self.is_constant(node.test) and not self.get_val(node.test):
            return None
        return node

class ConstantExpander(ast.NodeTransformer):
    "Substitutes --constant definitions for references to their names"

    @errorfy
    def visit_Name(self, node):
        """Expand a reference to a user supplied constant.

        self.options.constants maps a name to Python expression source text
        (NOTE(review): self.options is assumed to be assigned by whoever
        runs this transformer).  A load of such a name is replaced by the
        parsed expression, with every new node given the location of the
        original reference.  Assigning to or deleting a constant is an
        error; names that are not constants are returned unchanged.
        """
        if node.id not in self.options.constants:
            return node
        if isinstance(node.ctx, ast.Load):
            newnode=ast.parse(self.options.constants[node.id], mode="eval").body
            for n in ast.walk(newnode):
                ast.copy_location(n, node)
            return newnode
        if isinstance(node.ctx, ast.Store):
            raise Exception("You can't assign to a constant")
        if isinstance(node.ctx, ast.Del):
            # The name never exists at runtime, so deleting it can't work
            raise Exception("You can't delete a constant")
        return node

class AugmentedAssigns(ast.NodeTransformer):
    "Rewrites ``target op= value`` as ``target = target op value``"

    @errorfy
    def visit_AugAssign(self, node):
        """Return a plain Assign equivalent to the augmented assignment node.

        The target expression appears twice in the result (once loaded as
        the left operand, once as the store target), so any sub-expressions
        of it - e.g. the index in ``a[f()] += 1`` - are presumably
        evaluated twice by the generated code.
        """
        target=node.target
        # The original target has a Store context, which is not valid in an
        # operand position.  Make a shallow copy with a Load context; its
        # children (attribute object, subscript value/index) are already
        # loads so they can be shared.
        left=type(target)(**dict(ast.iter_fields(target)))
        left.ctx=ast.Load()
        ast.copy_location(left, target)
        binop=ast.BinOp(op=node.op, left=left, right=node.value)
        ast.copy_location(binop, node.value)
        assign=ast.Assign(targets=[node.target], value=binop)
        ast.copy_location(assign, node)
        return assign

# Currently only used by assert.  Turns nodes back into pretty printed
# text.  While we could stringify each node it would require inserting
# lots of parentheses to keep precedence and would look ugly.
# Fortunately we can instead keep track of line and column numbers and
# know that asserts extend to end of line.  This can still be fooled
# by multiline strings.
class Printer(ast.NodeVisitor):
    "Recovers the source text of a node from the file it was parsed from"

    def as_string(self, filename, node):
        """Return the source text for node, read back from filename.

        The text starts at the smallest column of any positioned (sub)node
        on the first line they occupy and runs to the end of the last line
        any of them starts on; surrounding whitespace is stripped.  Returns
        an empty string when no (sub)node carries position information.
        """
        self.linestart=1000000
        self.colstart=100000
        self.lineend=-1
        self.visit(node)
        res=[]
        for lineno in range(self.linestart, self.lineend+1):
            line=linecache.getline(filename, lineno)
            if lineno==self.linestart:
                line=line[self.colstart:]
            res.append(line)
        return "".join(res).strip()

    def generic_visit(self, node):
        "Widen the recorded line/column span to cover node, then recurse"
        if hasattr(node, "lineno"):
            if node.lineno<self.linestart:
                # A new first line: columns recorded for the old first
                # line are irrelevant, so start afresh from this node
                self.linestart=node.lineno
                self.colstart=node.col_offset
            elif node.lineno==self.linestart:
                self.colstart=min(node.col_offset, self.colstart)
            if node.lineno>self.lineend:
                self.lineend=node.lineno
        for f in node._fields:
            v=getattr(node, f, None)
            if isinstance(v, list):
                for x in v:
                    self.visit(x)
            elif isinstance(v, ast.AST):
                self.visit(v)

def dumper(options, infile, outfile=sys.stdout):
    """Decode a .jmp file read from infile and print a listing to outfile.

    The layout read is, with all 16-bit values little endian:
      - version (only 0 is understood; anything else exits with status 2)
      - string table: count, then each string as length + UTF-8 bytes
      - line number table: count, then (code position, line) pairs
      - code: length, then that many bytes
    Trailing data after the code is reported on stderr.  The decoded
    tables are passed to dump_internal for printing.
    """
    data=array.array("B")

    def getnbytes(n):
        # appends the next n bytes of infile to data
        (data.fromfile if py3 else data.read)(infile, n)

    def lastn(n):
        # The last n bytes read.  data[-n:] can't be used because it
        # returns the whole buffer when n is 0.
        return data[len(data)-n:]

    def as_bytes(arr):
        # array.tostring() was removed in Python 3.9 (tobytes replaces it)
        return arr.tobytes() if py3 else arr.tostring()

    # convenience
    def get16():
        getnbytes(2)
        return (data[-1]<<8)+data[-2]

    # version
    ver=get16()
    write("# JMP Version:", ver, file=outfile)
    if ver!=0:
        write("Don't know how to handle version", ver, file=sys.stderr)
        sys.exit(2)

    # string table
    strings=[]
    for i in range(get16()):
        slen=get16()
        if slen>0:
            getnbytes(slen)
            string=as_bytes(lastn(slen)).decode("utf8")
        else:
            string=""
        strings.append(string)

    # line number table
    linenotab=[]
    for i in range(get16()):
        linenotab.append( (get16(), get16()) )

    # code bytes
    codesize=get16()
    getnbytes(codesize)
    if infile.read() not in ("", b""):
        write("There is extra gunk at the end of the file - position", len(data), file=sys.stderr)
    code=lastn(codesize)
    dump_internal(options, strings, linenotab, code, outfile)

def dump_internal(options, strings, linenotab, code, outfile):
    """Write a readable disassembly of JMP code to outfile.

    strings is the string table, linenotab a sequence of (code position,
    source line) pairs and code the instruction bytes.  Opcodes below 128
    occupy one byte; the others are followed by a two byte little endian
    operand.  When options.dump_source names a file, the source line is
    shown before the instruction at each line-table position.  For string
    opcodes the referenced string table entry is shown too.

    Exits with status 6 on an unknown opcode or when the code ends part
    way through an instruction's operand.
    """

    getline=lambda x: linecache.getline(options.dump_source, x) if options.dump_source else ""

    if False:
        # prints tables
        write(">>> String table", file=outfile)
        for i,v in enumerate(strings):
            write("% 4d" % (i,), v, file=outfile)
        write(">>> Line number table", file=outfile)
        for pos,line in linenotab:
            write("% 4d  -> % 4d" % (pos, line), file=outfile)

    # map ints to strings
    bytecode_map=dict((v,k) for k,v in opcode_map.items())

    # code position -> source line; the first entry for a position wins
    line_at={}
    for p,l in linenotab:
        line_at.setdefault(p, l)

    # print the code
    pos=0
    while pos<len(code):
        if pos in line_at:
            l=line_at[pos]
            write("\n# % 4d %s" % (l, getline(l)), file=outfile)

        op=code[pos]
        if op not in bytecode_map:
            write("Unknown opcode %d at position %d" % (op, pos), file=sys.stderr)
            sys.exit(6)

        if op<128:
            write("% 5d" % pos, bytecode_map[op], file=outfile)
            pos+=1
            continue

        # two operand bytes must follow
        if pos+2>=len(code):
            write("Opcode %d at position %d is missing its operand" % (op, pos), file=sys.stderr)
            sys.exit(6)
        val=code[pos+1]+(code[pos+2]<<8)
        write("% 5d" % pos, bytecode_map[op], val, end="", file=outfile)
        if is_string_op(op):
            if val<len(strings):
                write("   \t# string", val, "=", repr(strings[val]), end="", file=outfile)
            else:
                write("   \t# string", val, "is out of range", end="", file=outfile)
        write(file=outfile)
        pos+=3

# this works just fine on the Objective-C code too
def java_maintenance(javafile):
    """Check and complete the opcode ``case`` labels in an interpreter source.

    Sections are delimited by lines beginning (after whitespace) with
    ``// -- check start :`` and ``// -- check end :``.  Inside a section,
    each ``case N: // NAME`` line is checked against opcode_map; a
    duplicate number, an unknown number or a wrong NAME is reported with
    its line number and the process exits with status 3.  Just before each
    end marker, a ``case`` line (indented like the start marker) is
    inserted for every opcode missing from that section.  Exits with
    status 3 if a section is unterminated or no section exists.  The file
    is only rewritten, and "Updated" printed, when its text changes.
    """
    out=[]
    seen=set()
    sections=0
    indent=None

    # opcode number -> name
    int_map=dict((v,s) for s,v in opcode_map.items())

    # Read once so the final comparison is against exactly the text that
    # was processed (a second read in another newline mode could report
    # spurious changes, and the "U" mode flag is gone in Python 3.11).
    with open(javafile, "rt") as java:
        lines=java.readlines()
    current="".join(lines)

    capturing=False
    for lineno,line in enumerate(lines, 1):
        if not capturing and line.lstrip().startswith("// -- check start :"):
            # each section is checked independently
            seen=set()
            sections+=1
            out.append(line)
            capturing=True
            indent=re.match(r"^(\s*)", line).group(1)
            continue
        if capturing:
            m=re.match(r"\s*case\s+(?P<val>[0-9]+):\s*//\s*(?P<name>\w+)\s*$", line)
            if m:
                val=int(m.group("val"))
                if val in seen:
                    write ("Second occurrence of",val,": line", lineno, file=sys.stderr)
                    sys.exit(3)
                if val not in int_map:
                    write ("Line",lineno,"opcode",val,"/", m.group("name"), "is not known", file=sys.stderr)
                    sys.exit(3)
                if int_map[val]!=m.group("name"):
                    write ("Line",lineno,"opcode",val,"has comment of",m.group("name"),"but should be", int_map[val], file=sys.stderr)
                    sys.exit(3)
                seen.add(val)
            if line.lstrip().startswith("// -- check end :"):
                for val,name in sorted(int_map.items()):
                    if val not in seen:
                        out.append("%scase %d: // %s\n" % (indent, val, name))
                capturing=False
        out.append(line)

    if capturing:
        write ("End marker not seen in", javafile, file=sys.stderr)
        sys.exit(3)

    # A section with no case lines yet is fine - it has just been filled in
    if not sections:
        write ("Special markers not identified in", javafile, file=sys.stderr)
        sys.exit(3)

    out="".join(out)
    if current!=out:
        with open(javafile, "wt") as java:
            java.write(out)
        write("Updated", javafile)



# The operations making up our bytecode.  We use a dict not a list so
# that there can be holes/changes without affecting the stability of
# the numbers.  Codes with the upper bit set (128 and up) are followed by
# an additional two byte little endian integer, used either as is or as
# an index into the string table.
#
# Maps opcode name -> opcode number.  dump_internal inverts this table to
# decode .jmp files, and java_maintenance checks the interpreter's
# "case N: // NAME" labels against it, so changing a number here also
# requires the interpreter source to change.
opcode_map={
    # 0-127: single byte opcodes with no operand
    "FUNCTION_PROLOG": 0,
    "ADD": 1,
    "MULT": 2,
    "DIV": 3,
    "GT": 4,
    "LT": 5,
    "EQ": 6,
    "IN": 7,
    "UNARY_ADD": 8,
    "RETURN": 9,
    "CALL": 10,
    "POP_TOP": 11,
    "ATTR": 12,
    "UNARY_NEG": 13,
    "SUBSCRIPT": 14,
    "SUBSCRIPT_SLICE": 15,
    "DICT": 16,
    "LIST": 17,
    "PUSH_NONE": 18,
    "EXIT_LOOP": 19,
    # 20 is currently unassigned
    "PUSH_TRUE": 21,
    "PUSH_FALSE": 22,
    "PRINT": 23,
    "NOT": 24,
    "ITER": 25,
    "NOT_EQ": 26,
    "SUB": 27,
    "DEL_INDEX": 28,
    "DEL_SLICE": 29,
    "MOD": 30,
    "GTE": 31,
    "LTE": 32,
    "ASSIGN_INDEX": 33,
    "ASSERT_FAILED": 34,
    "IS": 35,

    ### All the following take two byte int value
    # 128-159: operand is a code position (see is_code_op)
    "MAKE_METHOD": 128,
    "GOTO": 129,
    "IF_FALSE": 130,
    "NEXT": 131,
    "AND": 132,
    "OR": 133,

    # 160-199: operand indexes the string table (see is_string_op)
    "LOAD_NAME": 160,
    "STORE_NAME": 161,
    "PUSH_STR": 162,
    "GLOBAL": 163,
    "DEL_NAME": 164,
    "STORE_ATTR_NAME": 165,

    # 200 and up: operand is an integer used as is (see is_int_op)
    "PUSH_INT": 200,
    "PUSH_INT_HI": 201,
    "POP_N": 202,
}

def is_code_op(v):
    "True if opcode v's two byte operand is a code position (128-159)"
    return v>=128 and v<160

def is_string_op(v):
    "True if opcode v's two byte operand indexes the string table (160-199)"
    return not (v<160 or v>=200)

def is_int_op(v):
    "True if opcode v's two byte operand is a plain integer (200 and up)"
    return v>=200


if __name__=='__main__':
    # Command line entry point.  Three modes: --dump (list a .jmp file),
    # --java-maintenance (check interpreter case labels) and the default,
    # compiling one Python file to .jmp.
    import optparse
    p=optparse.OptionParser(usage="""%prog [options] file[s]""")

    g=optparse.OptionGroup(p, "Compiling", """Supply a filename
containing Python code.  By default the output is written
alongside the input file with a .jmp extension.  Supply a second
filename to override where the output goes.""")

    g.add_option("--traceback", action="store_true", default=False, help="Show detailed traceback on compile error")
    g.add_option("--print-function", action="store_true", default=sys.version_info>(3,), help="'print' is a function (Python 3 style) [%default]")
    g.add_option("--asserts", action="store_true", default=False, help="Include asserts in the output [%default]")
    g.add_option("--annotate", action="store_true", default=False, help="Place an annotated JMP/source file (extension .i) alongside the .jmp output")
    g.add_option("--syntax", action="store_false", dest="jmpoutput", default=True, help="Do a syntax check only - do not produce .jmp output")
    g.add_option("--omit-line-table", action="store_false", default=True, dest="line_table", help="Exclude line number table from .jmp output")
    g.add_option("--no-optimization", action="store_false", default=True, dest="optimization", help="Turn off optimizations")
    g.add_option("--constant", action="append", dest="constants", help="name=value Can be supplied multiple times.  value must be in Python syntax and can be any expression.  For example --constant DEBUG=True\t --constant VERSION=\"2alpha1\"\t --constant mapping={3: \"Three\", 4: \"a\"*7, 5: func(1,2,3)}.  Be careful of the shell altering what you provide on the command line.")
    p.add_option_group(g)

    g=optparse.OptionGroup(p, "Dumping", """Shows the internal representation of .jmp files""")
    g.add_option("--dump", action="store_true", default=False, help="Enables dump mode")
    g.add_option("--dump-source", metavar="INPUT.PY", help="Optional source file that the .jmp was compiled from.  The lines of the source will be intermingled with the output.  By default a .py file alongside the .jmp is used if present.")
    p.add_option_group(g)

    p.add_option("--java-maintenance", action="store_true", help=optparse.SUPPRESS_HELP)

    p.disable_interspersed_args()
    options, args=p.parse_args()

    # Turn the list of "name=expression" strings into a dict of
    # name -> expression source, checking both halves are usable.
    constants={}
    if options.constants:
        for c in options.constants:
            expr=c.split("=", 1)
            if len(expr)!=2:
                p.error("You need to provide an equals and a value for constant "+repr(c))
            name, value=expr
            name=name.strip()
            value=value.strip()
            # Only plain identifiers can ever be referenced (and so expanded)
            if not re.match(r"[A-Za-z_][A-Za-z0-9_]*$", name):
                p.error("Constant name %s is not a valid identifier" % (repr(name),))
            try:
                compile(value, value, "eval")
            except Exception:
                _,e,_=sys.exc_info()
                p.error("Error in constant '%s' value '%s': %s" % (name, value, str(e)))
            constants[name]=value
    options.constants=constants

    if options.dump:
        if len(args)!=1:
            p.error("--dump: One argument expected")
        if not options.dump_source:
            fn=os.path.splitext(args[0])[0]+".py"
            if os.path.exists(fn):
                options.dump_source=fn
        if options.dump_source:
            if not os.path.exists(options.dump_source):
                p.error("Source file \"%s\" not found" % (options.dump_source,))
            # ::TODO:: whine if source is newer than .jmp?
        with open(args[0], "rb") as infile:
            dumper(options, infile)
    elif options.java_maintenance:
        if len(args)!=1:
            p.error("--java-maintenance: One argument expected")
        java=args[0]
        if not os.path.isfile(java):
            p.error(java+" doesn't exist")
        java_maintenance(java)
    else:
        if len(args) not in (1,2):
            p.error("One or two filenames are expected, not "+str(len(args)))
        c=Compiler()
        c.compile(options, args[0], args[1] if len(args)==2 else os.path.splitext(args[0])[0]+".jmp")
